import gin
import torch
import spacy
import numpy as np
import os
import random
from sentence_transformers import SentenceTransformer

class cudaModel:
    """
    Wrapper bundling a SentenceTransformer encoder with a spaCy pipeline.

    Constructing an instance parses a Gin config file, seeds the random
    number generators, selects a device and loads both models. The
    ``set_random_seed``, ``set_device``, ``load_model``,
    ``load_spacy_processor`` and ``extract_sentences`` methods are decorated
    with ``gin.configurable``.
    """

    def __init__(self, seed: int = -1, config: str = 'cuda.gin'):
        """
        Parse the Gin config, then seed, pick a device and load the models.
        Args:
            seed (int): Seed forwarded to ``set_random_seed`` (-1 draws a
                random one). It is passed explicitly, so per Gin's documented
                precedence it wins over a Gin binding for that parameter.
            config (str): Path of the Gin configuration file to parse.
        """
        # Bindings must be parsed before any configurable method is called.
        gin.parse_config_file(config)

        # Order matters: load_model() reads self.device set by set_device().
        self.set_random_seed(seed)
        self.set_device()
        self.load_model()
        self.load_spacy_processor()

    @gin.configurable
    def set_random_seed(self, seed: int=42):
        """
        Seed Python, NumPy and PyTorch RNGs and set the cuDNN backend flags.
        Passing -1 replaces the seed with ``random.randint(0, 9999)``.
        The seed actually used is printed.
        Args:
            seed (int): Seed value; -1 means "pick one at random".
        """
        if seed == -1:
            seed = random.randint(0, 9999)
        print("Seed: {}".format(seed))

        # Seed every generator with the same value.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this
        # assignment does not re-seed hashing in the running process; it is
        # only visible to code that reads the environment afterwards.
        os.environ["PYTHONHASHSEED"] = str(seed)

        # cuDNN backend switches: deterministic on, benchmark and cuDNN off.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.enabled = False

    @gin.configurable
    def set_device(self, device=''):
        """
        Store the device name on ``self.device`` and print it.
        Args:
            device: Device to use. A falsy value (the default empty string)
                selects 'cuda' when ``torch.cuda.is_available()`` is true and
                'cpu' otherwise.
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using {self.device}")

    @gin.configurable
    def load_model(self, transformer_model='all-MiniLM-L6-v2'):
        """
        Instantiate the SentenceTransformer and move it to ``self.device``.
        Requires ``set_device`` to have been called first.
        Args:
            transformer_model: Name or path handed to ``SentenceTransformer``.
        """
        self.model = SentenceTransformer(transformer_model)
        # Rebind to the value returned by .to() once the model is on the device.
        self.model = self.model.to(self.device)

    @gin.configurable
    def load_spacy_processor(self, spacy_model='en_core_web_md'):
        """
        Load the spaCy pipeline into ``self.spacy_processor``.
        Args:
            spacy_model: Name or path handed to ``spacy.load``.
        """
        self.spacy_processor = spacy.load(spacy_model)

    def encode(self, text):
        """
        Embed ``text`` with the loaded SentenceTransformer model.
        Args:
            text: Input forwarded unchanged to ``self.model.encode``
                (presumably a sentence or a list of sentences -- confirm
                against callers).
        Returns:
            The value of ``self.model.encode(text, convert_to_tensor=True)``
            after ``.to(self.device)``.
        """
        embeddings = self.model.encode(text, convert_to_tensor=True)
        return embeddings.to(self.device)

    @gin.configurable
    def extract_sentences(self, document, min_words=4):
        """
        Split ``document`` into sentences with spaCy and drop the short ones.
        Args:
            document: Text passed to ``self.spacy_processor``.
            min_words (int): Threshold on the whitespace-separated token
                count. The comparison is strict: a sentence is kept only when
                it has MORE than ``min_words`` tokens.
        Returns:
            list of str: Text of each retained sentence, in document order.
        """
        parsed = self.spacy_processor(document)
        kept = []
        for span in parsed.sents:
            sentence_text = span.text
            # Word count is a plain str.split(), not spaCy's tokenisation.
            if len(sentence_text.split()) > min_words:
                kept.append(sentence_text)
        return kept

    def get_device(self):
        """Return the device name stored by ``set_device``."""
        return self.device